{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train a Ridge Regression Model on the Diabetes Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook loads the Diabetes dataset from sklearn, splits the data into training and validation sets, trains a Ridge regression model, validates the model on the validation set, and saves the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_diabetes\n",
    "from sklearn.linear_model import Ridge\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from sklearn.model_selection import train_test_split\n",
    "import joblib\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_data = load_diabetes()\n",
    "\n",
    "df = pd.DataFrame(\n",
    "    data=sample_data.data,\n",
    "    columns=sample_data.feature_names)\n",
    "df['Y'] = sample_data.target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(442, 10)\n"
     ]
    }
   ],
   "source": [
    "print(df.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>sex</th>\n",
       "      <th>bmi</th>\n",
       "      <th>bp</th>\n",
       "      <th>s1</th>\n",
       "      <th>s2</th>\n",
       "      <th>s3</th>\n",
       "      <th>s4</th>\n",
       "      <th>s5</th>\n",
       "      <th>s6</th>\n",
       "      <th>Y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>count</td>\n",
       "      <td>4.420000e+02</td>\n",
       "      <td>4.420000e+02</td>\n",
       "      <td>4.420000e+02</td>\n",
       "      <td>4.420000e+02</td>\n",
       "      <td>4.420000e+02</td>\n",
       "      <td>4.420000e+02</td>\n",
       "      <td>4.420000e+02</td>\n",
       "      <td>4.420000e+02</td>\n",
       "      <td>4.420000e+02</td>\n",
       "      <td>4.420000e+02</td>\n",
       "      <td>442.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>mean</td>\n",
       "      <td>-3.634285e-16</td>\n",
       "      <td>1.308343e-16</td>\n",
       "      <td>-8.045349e-16</td>\n",
       "      <td>1.281655e-16</td>\n",
       "      <td>-8.835316e-17</td>\n",
       "      <td>1.327024e-16</td>\n",
       "      <td>-4.574646e-16</td>\n",
       "      <td>3.777301e-16</td>\n",
       "      <td>-3.830854e-16</td>\n",
       "      <td>-3.412882e-16</td>\n",
       "      <td>152.133484</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>std</td>\n",
       "      <td>4.761905e-02</td>\n",
       "      <td>4.761905e-02</td>\n",
       "      <td>4.761905e-02</td>\n",
       "      <td>4.761905e-02</td>\n",
       "      <td>4.761905e-02</td>\n",
       "      <td>4.761905e-02</td>\n",
       "      <td>4.761905e-02</td>\n",
       "      <td>4.761905e-02</td>\n",
       "      <td>4.761905e-02</td>\n",
       "      <td>4.761905e-02</td>\n",
       "      <td>77.093005</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>min</td>\n",
       "      <td>-1.072256e-01</td>\n",
       "      <td>-4.464164e-02</td>\n",
       "      <td>-9.027530e-02</td>\n",
       "      <td>-1.123996e-01</td>\n",
       "      <td>-1.267807e-01</td>\n",
       "      <td>-1.156131e-01</td>\n",
       "      <td>-1.023071e-01</td>\n",
       "      <td>-7.639450e-02</td>\n",
       "      <td>-1.260974e-01</td>\n",
       "      <td>-1.377672e-01</td>\n",
       "      <td>25.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25%</td>\n",
       "      <td>-3.729927e-02</td>\n",
       "      <td>-4.464164e-02</td>\n",
       "      <td>-3.422907e-02</td>\n",
       "      <td>-3.665645e-02</td>\n",
       "      <td>-3.424784e-02</td>\n",
       "      <td>-3.035840e-02</td>\n",
       "      <td>-3.511716e-02</td>\n",
       "      <td>-3.949338e-02</td>\n",
       "      <td>-3.324879e-02</td>\n",
       "      <td>-3.317903e-02</td>\n",
       "      <td>87.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50%</td>\n",
       "      <td>5.383060e-03</td>\n",
       "      <td>-4.464164e-02</td>\n",
       "      <td>-7.283766e-03</td>\n",
       "      <td>-5.670611e-03</td>\n",
       "      <td>-4.320866e-03</td>\n",
       "      <td>-3.819065e-03</td>\n",
       "      <td>-6.584468e-03</td>\n",
       "      <td>-2.592262e-03</td>\n",
       "      <td>-1.947634e-03</td>\n",
       "      <td>-1.077698e-03</td>\n",
       "      <td>140.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>75%</td>\n",
       "      <td>3.807591e-02</td>\n",
       "      <td>5.068012e-02</td>\n",
       "      <td>3.124802e-02</td>\n",
       "      <td>3.564384e-02</td>\n",
       "      <td>2.835801e-02</td>\n",
       "      <td>2.984439e-02</td>\n",
       "      <td>2.931150e-02</td>\n",
       "      <td>3.430886e-02</td>\n",
       "      <td>3.243323e-02</td>\n",
       "      <td>2.791705e-02</td>\n",
       "      <td>211.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>max</td>\n",
       "      <td>1.107267e-01</td>\n",
       "      <td>5.068012e-02</td>\n",
       "      <td>1.705552e-01</td>\n",
       "      <td>1.320442e-01</td>\n",
       "      <td>1.539137e-01</td>\n",
       "      <td>1.987880e-01</td>\n",
       "      <td>1.811791e-01</td>\n",
       "      <td>1.852344e-01</td>\n",
       "      <td>1.335990e-01</td>\n",
       "      <td>1.356118e-01</td>\n",
       "      <td>346.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                age           sex           bmi            bp            s1  \\\n",
       "count  4.420000e+02  4.420000e+02  4.420000e+02  4.420000e+02  4.420000e+02   \n",
       "mean  -3.634285e-16  1.308343e-16 -8.045349e-16  1.281655e-16 -8.835316e-17   \n",
       "std    4.761905e-02  4.761905e-02  4.761905e-02  4.761905e-02  4.761905e-02   \n",
       "min   -1.072256e-01 -4.464164e-02 -9.027530e-02 -1.123996e-01 -1.267807e-01   \n",
       "25%   -3.729927e-02 -4.464164e-02 -3.422907e-02 -3.665645e-02 -3.424784e-02   \n",
       "50%    5.383060e-03 -4.464164e-02 -7.283766e-03 -5.670611e-03 -4.320866e-03   \n",
       "75%    3.807591e-02  5.068012e-02  3.124802e-02  3.564384e-02  2.835801e-02   \n",
       "max    1.107267e-01  5.068012e-02  1.705552e-01  1.320442e-01  1.539137e-01   \n",
       "\n",
       "                 s2            s3            s4            s5            s6  \\\n",
       "count  4.420000e+02  4.420000e+02  4.420000e+02  4.420000e+02  4.420000e+02   \n",
       "mean   1.327024e-16 -4.574646e-16  3.777301e-16 -3.830854e-16 -3.412882e-16   \n",
       "std    4.761905e-02  4.761905e-02  4.761905e-02  4.761905e-02  4.761905e-02   \n",
       "min   -1.156131e-01 -1.023071e-01 -7.639450e-02 -1.260974e-01 -1.377672e-01   \n",
       "25%   -3.035840e-02 -3.511716e-02 -3.949338e-02 -3.324879e-02 -3.317903e-02   \n",
       "50%   -3.819065e-03 -6.584468e-03 -2.592262e-03 -1.947634e-03 -1.077698e-03   \n",
       "75%    2.984439e-02  2.931150e-02  3.430886e-02  3.243323e-02  2.791705e-02   \n",
       "max    1.987880e-01  1.811791e-01  1.852344e-01  1.335990e-01  1.356118e-01   \n",
       "\n",
       "                Y  \n",
       "count  442.000000  \n",
       "mean   152.133484  \n",
       "std     77.093005  \n",
       "min     25.000000  \n",
       "25%     87.000000  \n",
       "50%    140.500000  \n",
       "75%    211.500000  \n",
       "max    346.000000  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# All data in a single dataframe\n",
    "df.describe()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Split Data into Training and Validation Sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = df.drop('Y', axis=1).values\n",
    "y = df['Y'].values\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, test_size=0.2, random_state=0)\n",
    "data = {\"train\": {\"X\": X_train, \"y\": y_train},\n",
    "        \"test\": {\"X\": X_test, \"y\": y_test}}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train Model on Training Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Ridge(alpha=0.5, copy_X=True, fit_intercept=True, max_iter=None,\n",
       "      normalize=False, random_state=None, solver='auto', tol=0.001)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# experiment parameters\n",
    "args = {\n",
    "    \"alpha\": 0.5\n",
    "}\n",
    "\n",
    "reg_model = Ridge(**args)\n",
    "reg_model.fit(data[\"train\"][\"X\"], data[\"train\"][\"y\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validate Model on Validation Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'mse': 3298.9096058070622}\n"
     ]
    }
   ],
   "source": [
    "preds = reg_model.predict(data[\"test\"][\"X\"])\n",
    "mse = mean_squared_error(preds, y_test)\n",
    "metrics = {\"mse\": mse}\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['sklearn_regression_model.pkl']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_name = \"sklearn_regression_model.pkl\"\n",
    "\n",
    "joblib.dump(value=reg, filename=model_name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
